package paddledjl;

import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;
import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;
import ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

import java.nio.file.Paths;
import java.util.concurrent.ConcurrentHashMap;

/**
 * DJL PaddleOCR configuration.
 *
 * <p>Builds DJL predictors for the three PaddleOCR stages (text detection, text-direction
 * classification and text recognition) from model archives stored on the local file system.
 *
 * <p>NOTE(review): the archive locations below are hard-coded absolute Windows paths
 * ({@code F:\testmodel\...}); they must exist on the machine running this code.
 *
 * <p>NOTE(review): every factory call loads the model again via {@code Criteria.loadModel()}
 * and hands back a new predictor. The {@code ZooModel} the predictor was created from is not
 * exposed to the caller, so it cannot be closed from outside this class. Callers should reuse
 * the returned predictor rather than requesting a new one per image.
 *
 * <p>Created 2021-12-13 15:47.
 *
 * @author 小帅丶
 */
public class DJLPaddleOCRConfig {
    /** Name of the DJL engine used to load every model. */
    private static final String ENGINE_NAME = "PaddlePaddle";

    // --- DEFAULT: model archives used by the official DJL documentation example ---

    private static final String DETECTION_MODEL_PATH = "F:\\testmodel\\det_db.zip";

    private static final String DIRECTION_MODEL_PATH = "F:\\testmodel\\cls.zip";

    private static final String RECOGNITION_MODEL_PATH = "F:\\testmodel\\rec_crnn.zip";

    // --- Direction classifier shared by every non-DEFAULT model version ---

    private static final String DIRECTION_MODEL_PATH_GENERAL = "F:\\testmodel\\ch_ppocr_mobile_v2.0_cls_infer.tar";

    // --- ch_ppocr_mobile_v2: ultra-lightweight PP-OCR mobile models ---

    private static final String DETECTION_MODEL_PATH_V2_MOBILE = "F:\\testmodel\\ch_ppocr_mobile_v2.0_det_infer.tar";

    private static final String RECOGNITION_MODEL_PATH_V2_MOBILE = "F:\\testmodel\\ch_ppocr_mobile_v2.0_rec_infer.tar";

    // --- ch_PP_OCRv2: ultra-lightweight PP-OCRv2 models ---

    private static final String DETECTION_MODEL_PATH_V2 = "F:\\testmodel\\ch_PP-OCRv2_det_infer.tar";

    private static final String RECOGNITION_MODEL_PATH_V2 = "F:\\testmodel\\ch_PP-OCRv2_rec_infer.tar";

    // --- ch_ppocr_server_v2: general-purpose PP-OCR server models ---

    private static final String DETECTION_MODEL_PATH_GENERAL = "F:\\testmodel\\ch_ppocr_server_v2.0_det_infer.tar";

    private static final String RECOGNITION_MODEL_PATH_GENERAL = "F:\\testmodel\\ch_ppocr_server_v2.0_rec_infer.tar";

    /**
     * Model versions that can be selected.
     *
     * <ul>
     *   <li>{@code DEFAULT} - model version used by the official DJL documentation example</li>
     *   <li>{@code ch_PP_OCRv2} - Chinese/English ultra-lightweight PP-OCRv2 model</li>
     *   <li>{@code ch_ppocr_server_v2} - Chinese/English general-purpose PP-OCR server model</li>
     *   <li>{@code ch_ppocr_mobile_v2} - Chinese/English ultra-lightweight PP-OCR mobile model</li>
     * </ul>
     */
    public enum MODEL_VERSION{
        DEFAULT,ch_PP_OCRv2,ch_ppocr_server_v2,ch_ppocr_mobile_v2;
    }

    /** OCR pipeline stage a model archive belongs to. */
    private enum MODEL_TYPE{
        DETECT,DIRECTION,RECOGNITION;
    }


    /**
     * Loads the text-detection model of the {@link MODEL_VERSION#DEFAULT} version.
     *
     * @return a predictor mapping an image to the detected objects
     * @throws Exception if the model cannot be located or loaded
     */
    public static Predictor<Image, DetectedObjects> getDetectionModel () throws Exception{
       return getDetectionModel(MODEL_VERSION.DEFAULT);
    }

    /**
     * Loads the text-direction classification model of the {@link MODEL_VERSION#DEFAULT} version.
     *
     * @return a predictor mapping an image to its direction classifications
     * @throws Exception if the model cannot be located or loaded
     */
    public static Predictor<Image, Classifications> getDirectionModel () throws Exception{
        return getDirectionModel(MODEL_VERSION.DEFAULT);
    }

    /**
     * Loads the text-recognition model of the {@link MODEL_VERSION#DEFAULT} version.
     *
     * @return a predictor mapping an image to the recognised text
     * @throws Exception if the model cannot be located or loaded
     */
    public static Predictor<Image, String> getRecognitionModel () throws Exception{
        return getRecognitionModel(MODEL_VERSION.DEFAULT);
    }

    /**
     * Loads the text-detection model for the requested version.
     *
     * @param model_version model version to load; must not be {@code null}
     * @return a predictor mapping an image to the detected objects
     * @throws IllegalArgumentException if {@code model_version} is {@code null}
     * @throws Exception if the model cannot be located or loaded
     */
    public static Predictor<Image, DetectedObjects> getDetectionModel (MODEL_VERSION model_version) throws Exception{
        String modelPath = getModelPath(model_version, MODEL_TYPE.DETECT);
        // A hosted archive could be used instead of the local file via
        // .optModelUrls("https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip")
        Criteria<Image, DetectedObjects> criteria = Criteria.builder()
                .optEngine(ENGINE_NAME)
                .setTypes(Image.class, DetectedObjects.class)
                .optModelPath(Paths.get(modelPath))
                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<>()))
                .build();
        ZooModel<Image, DetectedObjects> zooModel = criteria.loadModel();
        return zooModel.newPredictor();
    }

    /**
     * Loads the text-direction classification model for the requested version.
     *
     * <p>{@link MODEL_VERSION#DEFAULT} selects the DEFAULT classifier archive; every other
     * value selects the single general-purpose classifier archive.
     *
     * @param model_version model version to load
     * @return a predictor mapping an image to its direction classifications
     * @throws Exception if the model cannot be located or loaded
     */
    public static Predictor<Image, Classifications> getDirectionModel (MODEL_VERSION model_version) throws Exception{
        String modelPath = getModelPath(model_version, MODEL_TYPE.DIRECTION);
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .optEngine(ENGINE_NAME)
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get(modelPath))
                .optTranslator(new PpWordRotateTranslator())
                .build();
        ZooModel<Image, Classifications> zooModel = criteria.loadModel();
        return zooModel.newPredictor();
    }

    /**
     * Loads the text-recognition model for the requested version.
     *
     * @param model_version model version to load; must not be {@code null}
     * @return a predictor mapping an image to the recognised text
     * @throws IllegalArgumentException if {@code model_version} is {@code null}
     * @throws Exception if the model cannot be located or loaded
     */
    public static Predictor<Image, String> getRecognitionModel (MODEL_VERSION model_version) throws Exception{
        String modelPath = getModelPath(model_version, MODEL_TYPE.RECOGNITION);
        // A hosted archive could be used instead of the local file via
        // .optModelUrls("https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip")
        Criteria<Image, String> criteria = Criteria.builder()
                .optEngine(ENGINE_NAME)
                .setTypes(Image.class, String.class)
                .optModelPath(Paths.get(modelPath))
                .optTranslator(new PpWordRecognitionTranslator())
                .build();
        ZooModel<Image, String> zooModel = criteria.loadModel();
        return zooModel.newPredictor();
    }

    /**
     * Resolves the local archive path for a model version and pipeline stage.
     *
     * <p>For {@code DIRECTION} only {@link MODEL_VERSION#DEFAULT} has its own archive; any other
     * value maps to the shared general-purpose classifier. For {@code DETECT} and
     * {@code RECOGNITION} each version has a dedicated archive.
     *
     * @param model_version model version
     * @param model_type pipeline stage; must not be {@code null}
     * @return the configured archive path, never empty
     * @throws IllegalArgumentException if {@code model_type} is {@code null}, or if
     *     {@code model_version} is {@code null} or unsupported for a detection/recognition lookup
     */
    private static String getModelPath(MODEL_VERSION model_version, MODEL_TYPE model_type) {
        if (model_type == null) {
            throw new IllegalArgumentException("model_type must not be null");
        }
        if (model_type == MODEL_TYPE.DIRECTION) {
            return model_version == MODEL_VERSION.DEFAULT
                    ? DIRECTION_MODEL_PATH
                    : DIRECTION_MODEL_PATH_GENERAL;
        }
        // Fail fast: an empty result would reach Paths.get("") and be treated as the
        // working directory, hiding the real mistake behind a model-loading error.
        if (model_version == null) {
            throw new IllegalArgumentException(
                    "model_version must not be null for " + model_type + " models");
        }
        boolean detect = model_type == MODEL_TYPE.DETECT;
        switch (model_version) {
            case DEFAULT:
                return detect ? DETECTION_MODEL_PATH : RECOGNITION_MODEL_PATH;
            case ch_PP_OCRv2:
                return detect ? DETECTION_MODEL_PATH_V2 : RECOGNITION_MODEL_PATH_V2;
            case ch_ppocr_mobile_v2:
                return detect ? DETECTION_MODEL_PATH_V2_MOBILE : RECOGNITION_MODEL_PATH_V2_MOBILE;
            case ch_ppocr_server_v2:
                return detect ? DETECTION_MODEL_PATH_GENERAL : RECOGNITION_MODEL_PATH_GENERAL;
            default:
                throw new IllegalArgumentException("Unsupported model version: " + model_version);
        }
    }
}
